{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for preparing and saving TSP graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download TSP dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isfile('TSP.zip'):\n",
    "    print('downloading..')\n",
    "    !curl https://www.dropbox.com/s/1wf6zn5nq7qjg0e/TSP.zip?dl=1 -o TSP.zip -J -L -k\n",
    "    !unzip TSP.zip -d ../\n",
    "    # !tar -xvf TSP.zip -C ../\n",
    "else:\n",
    "    print('File already downloaded')\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to DGL format and save with pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../../') # go to root folder of the project\n",
    "print(os.getcwd())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from data.data import LoadData\n",
    "from data.TSP import TSP, TSPDatasetDGL, TSPDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "\n",
    "DATASET_NAME = 'TSP'\n",
    "dataset = TSPDatasetDGL(DATASET_NAME) \n",
    "\n",
    "print('Time (sec):',time.time() - start) # ~ 30 mins\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_histo_graphs(dataset, title):\n",
    "    # histogram of graph sizes\n",
    "    graph_sizes = []\n",
    "    for graph in dataset:\n",
    "        graph_sizes.append(graph[0].number_of_nodes())\n",
    "        #graph_sizes.append(graph[0].number_of_edges())\n",
    "    plt.figure(1)\n",
    "    plt.hist(graph_sizes, bins=20)\n",
    "    plt.title(title)\n",
    "    plt.show()\n",
    "    graph_sizes = torch.Tensor(graph_sizes)\n",
    "    print('nb/min/max :',len(graph_sizes),graph_sizes.min().long().item(),graph_sizes.max().long().item())\n",
    "    \n",
    "plot_histo_graphs(dataset.train,'trainset')\n",
    "plot_histo_graphs(dataset.val,'valset')\n",
    "plot_histo_graphs(dataset.test,'testset')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(dataset.train))\n",
    "print(len(dataset.val))\n",
    "print(len(dataset.test))\n",
    "\n",
    "print(dataset.train[0])\n",
    "print(dataset.val[0])\n",
    "print(dataset.test[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "\n",
    "with open('data/TSP/TSP.pkl','wb') as f:\n",
    "    pickle.dump([dataset.train,dataset.val,dataset.test],f)\n",
    "        \n",
    "print('Time (sec):',time.time() - start) # 58s\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test load function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'TSP'\n",
    "dataset = LoadData(DATASET_NAME)  # 20s\n",
    "trainset, valset, testset = dataset.train, dataset.val, dataset.test\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "\n",
    "batch_size = 10\n",
    "collate = TSPDataset.collate\n",
    "train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)\n",
    "\n",
    "print('Time (sec):',time.time() - start)  # 0.0003s\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot TSP samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import pdist, squareform\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "\n",
    "\n",
    "def _edges_to_node_pairs(W):\n",
    "    \"\"\"Helper function to convert edge matrix into pairs of adjacent nodes.\n",
    "    \"\"\"\n",
    "    pairs = []\n",
    "    for r in range(len(W)):\n",
    "        for c in range(len(W)):\n",
    "            if W[r][c] == 1:\n",
    "                pairs.append((r, c))\n",
    "    return pairs\n",
    "\n",
    "def plot_tsp(x_coord, W, tour):\n",
    "    \"\"\"\n",
    "    Helper function to plot TSP tours.\n",
    "    \n",
    "    Args:\n",
    "        x_coord: Coordinates of nodes\n",
    "        W: Graph as adjacency matrix\n",
    "        tour: Predicted tour\n",
    "        title: Title of figure/subplot\n",
    "    \"\"\"\n",
    "    W_val = squareform(pdist(x_coord, metric='euclidean'))\n",
    "    G = nx.from_numpy_matrix(W_val)\n",
    "    \n",
    "    pos = dict(zip(range(len(x_coord)), x_coord))\n",
    "    \n",
    "    adj_pairs = _edges_to_node_pairs(W)\n",
    "    \n",
    "    tour_pairs = []\n",
    "    for idx in range(len(tour)-1):\n",
    "        tour_pairs.append((tour[idx], tour[idx+1]))\n",
    "    tour_pairs.append((tour[idx+1], tour[0]))\n",
    "    \n",
    "    node_size = 50/(len(x_coord)//50)\n",
    "    \n",
    "    nx.draw_networkx_nodes(G, pos, node_color='b', node_size=node_size)\n",
    "    nx.draw_networkx_edges(G, pos, edgelist=adj_pairs, alpha=0.25, width=0.1)\n",
    "    nx.draw_networkx_edges(G, pos, edgelist=tour_pairs, alpha=1, width=1.5, edge_color='r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = \"data/TSP/tsp50-500_test.txt\"\n",
    "file_data = open(filename, \"r\").readlines()\n",
    "num_neighbors = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for graph_idx, line in enumerate(file_data):\n",
    "    line = line.split(\" \")  # Split into list\n",
    "    num_nodes = int(line.index('output')//2)\n",
    "\n",
    "    # Convert node coordinates to required format\n",
    "    nodes_coord = []\n",
    "    for idx in range(0, 2 * num_nodes, 2):\n",
    "        nodes_coord.append([float(line[idx]), float(line[idx + 1])])\n",
    "\n",
    "    # Compute distance matrix\n",
    "    W_val = squareform(pdist(nodes_coord, metric='euclidean'))\n",
    "    # Determine k-nearest neighbors for each node\n",
    "    knns = np.argpartition(W_val, kth=num_neighbors, axis=-1)[:, num_neighbors::-1]\n",
    "    \n",
    "    W = np.zeros((num_nodes, num_nodes))\n",
    "    # Make connections \n",
    "    for idx in range(num_nodes):\n",
    "        W[idx][knns[idx]] = 1\n",
    "\n",
    "    # Convert tour nodes to required format\n",
    "    # Don't add final connection for tour/cycle\n",
    "    tour_nodes = [int(node) - 1 for node in line[line.index('output') + 1:-1]][:-1]\n",
    "\n",
    "    # Compute an edge adjacency matrix representation of tour\n",
    "    edges_target = np.zeros((num_nodes, num_nodes))\n",
    "    for idx in range(len(tour_nodes) - 1):\n",
    "        i = tour_nodes[idx]\n",
    "        j = tour_nodes[idx + 1]\n",
    "        edges_target[i][j] = 1\n",
    "        edges_target[j][i] = 1\n",
    "    # Add final connection of tour in edge target\n",
    "    edges_target[j][tour_nodes[0]] = 1\n",
    "    edges_target[tour_nodes[0]][j] = 1\n",
    "    \n",
    "    if num_nodes == 498:\n",
    "        print(num_nodes)\n",
    "        print(tour_nodes)\n",
    "\n",
    "        plt.figure(figsize=(5,5))\n",
    "        plot_tsp(nodes_coord, W, tour_nodes)\n",
    "        plt.savefig(f\"img/tsp{num_nodes}_{graph_idx}.pdf\", format='pdf', dpi=1200, bbox_inches='tight')\n",
    "        plt.show()\n",
    "        print(\"Stop (y/n)\")\n",
    "        if input() == 'y':\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
